package com.wind.utils.helper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.function.Consumer;

/**
 * <p>
 *      JDBC utility class: opens connections from the configured driver / user / password / URL
 *      and offers helpers for table and column metadata lookups, queries, arbitrary SQL
 *      execution and connectivity tests.
 * </p>
 * <p>
 *      Every public instance method opens its own connection via {@link #getConn()} and closes
 *      it (together with any statement / result set it created) before returning.
 * </p>
 * @author wind
 * @date    2024/12/11 10:46
 * @version v1.0
 */
public class JdbcHelper {

    private static final Logger logger = LoggerFactory.getLogger(JdbcHelper.class);

    /** Fully-qualified JDBC driver class name; loaded via {@link Class#forName(String)}. */
    private final String driver;

    /** Database user name; used as the Oracle schema in {@link #getTables(String)}. */
    private final String user;

    /** Database password; may be {@code null}, in which case no password property is sent. */
    private final String password;

    /** JDBC connection URL; the database name is derived from it by {@link #getDataBase()}. */
    private final String jdbcUrl;

    /**
     * Creates a helper for the given connection settings. No connection is opened here.
     *
     * @param driver   fully-qualified JDBC driver class name
     * @param user     database user name
     * @param password database password, may be {@code null}
     * @param jdbcUrl  JDBC connection URL
     */
    public JdbcHelper(String driver, String user, String password, String jdbcUrl) {
        this.driver = driver;
        this.user = user;
        this.password = password;
        this.jdbcUrl = jdbcUrl;
    }

    /**
     * Opens a new database connection.
     *
     * @return a new connection; the caller is responsible for closing it
     * @throws RuntimeException if the driver class cannot be loaded or the connection cannot be
     *                          established; the original exception is preserved as the cause
     */
    public Connection getConn() {
        try {
            Class.forName(driver);
            Properties props = new Properties();
            // Properties is a Hashtable and rejects null values, so only set what is present.
            if (user != null) {
                props.setProperty("user", user);
            }
            if (password != null) {
                props.setProperty("password", password);
            }
            // NOTE(review): SSL and server-certificate verification are disabled unconditionally;
            // presumably drivers that do not know these properties simply ignore them — confirm.
            props.setProperty("useSSL", "false");
            props.setProperty("verifyServerCertificate", "false");
            return DriverManager.getConnection(jdbcUrl, props);
        } catch (ClassNotFoundException | SQLException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Lists the names of tables (type {@code "TABLE"}) whose name contains {@code tableName}.
     * <p>
     * Metadata errors are logged and result in the tables collected so far (usually an empty
     * list) being returned; connection failures propagate as {@link RuntimeException}.
     *
     * @param tableName substring to match in table names; {@code null} matches every table
     * @return matching table names, never {@code null}
     */
    public List<String> getTables(String tableName) {
        String catalog = getDataBase();
        List<String> tables = new ArrayList<>();
        // Without this guard a null filter would become the literal pattern "%null%".
        String tableNamePattern = tableName == null ? "%" : "%" + tableName + "%";
        Connection con = null;
        ResultSet rs = null;
        try {
            con = getConn();
            DatabaseMetaData db = con.getMetaData();
            // For impala (and hive) schemaPattern must not be empty: it has to be the database
            // name, otherwise tables from all databases are returned.
            String schemaPattern = null;
            if (driver.contains("hive") || driver.contains("impala")) {
                schemaPattern = catalog;
            } else if (driver.contains("oracle")) {
                // Oracle: use the upper-cased user name as the schema.
                schemaPattern = upperCase(user);
            }
            rs = db.getTables(catalog, schemaPattern, tableNamePattern, new String[]{"TABLE"});
            while (rs.next()) {
                tables.add(rs.getString("TABLE_NAME"));
            }
        } catch (SQLException e) {
            logger.error("JdbcUtil.getTables failed, err is {}", e.getMessage(), e);
        } finally {
            close(con, null, rs);
        }
        return tables;
    }

    /**
     * Lists the column names of the given table.
     * <p>
     * For Oracle drivers the table name is upper-cased first (presumably because Oracle stores
     * unquoted identifiers in upper case). Metadata errors are logged and result in the columns
     * collected so far being returned; connection failures propagate as {@link RuntimeException}.
     *
     * @param tableName table name (pattern) passed to {@link DatabaseMetaData#getColumns}
     * @return column names, never {@code null}
     */
    public List<String> getColumns(String tableName) {
        List<String> columns = new ArrayList<>();
        String catalog = getDataBase();
        String tableNamePattern = driver.contains("oracle") ? upperCase(tableName) : tableName;
        Connection con = null;
        ResultSet rs = null;
        try {
            con = getConn();
            DatabaseMetaData db = con.getMetaData();
            rs = db.getColumns(catalog, "%", tableNamePattern, "%");
            while (rs.next()) {
                columns.add(rs.getString("COLUMN_NAME"));
            }
        } catch (SQLException e) {
            logger.error("JdbcUtil.getColumns failed, err is {}", e.getMessage(), e);
        } finally {
            close(con, null, rs);
        }
        return columns;
    }

    /**
     * Runs a query and hands the open result set to {@code consumer}.
     * <p>
     * The result set is only valid inside {@code consumer}; it is closed together with the
     * statement and connection before this method returns.
     *
     * @param sql      query to execute
     * @param consumer callback receiving the result set, may be {@code null}
     * @return {@code true} when the query completed
     * @throws RuntimeException wrapping the {@link SQLException} if the query fails
     */
    public boolean executeQuery(String sql, Consumer<ResultSet> consumer) {
        Connection conn = getConn();
        Statement stmt = null;
        ResultSet rs = null;
        int queryTimeout = 0;
        try {
            // A plain Statement is required: JDBC forbids calling executeQuery(String)
            // on a PreparedStatement.
            stmt = conn.createStatement();
            queryTimeout = stmt.getQueryTimeout();
            rs = stmt.executeQuery(sql);
            if (consumer != null) {
                consumer.accept(rs);
            }
            return true;
        } catch (SQLException e) {
            logger.error("JdbcHelper executeQuery failed, err is {}, queryTimeout is {}",
                    e.getMessage(), queryTimeout, e);
            throw new RuntimeException(e);
        } finally {
            close(conn, stmt, rs);
        }
    }

    /**
     * Executes an arbitrary SQL statement (typically DDL).
     *
     * @param sql statement to execute
     * @return the value of {@link Statement#execute(String)}: {@code true} if the first result
     *         is a result set, {@code false} for an update count or no result
     * @throws RuntimeException wrapping the {@link SQLException} if execution fails
     */
    public boolean executeSql(String sql) {
        Connection conn = getConn();
        Statement stmt = null;
        try {
            // A plain Statement is required: JDBC forbids calling execute(String)
            // on a PreparedStatement.
            stmt = conn.createStatement();
            return stmt.execute(sql);
        } catch (SQLException e) {
            logger.error("db executeSql failed, err is {}", e.getMessage(), e);
            throw new RuntimeException(e);
        } finally {
            close(conn, stmt, null);
        }
    }

    /**
     * Tests connectivity by running a probe query.
     *
     * @param testSql probe query; when {@code null} a default is used ({@code "select 1"}, or
     *                {@code "select 1 from dual"} for Oracle drivers)
     * @return {@code true} if the probe query executed
     * @throws RuntimeException if connecting fails or the probe query fails
     */
    public boolean test(String testSql) {
        String sql = testSql;
        if (sql == null) {
            // Default connection probe; only the default depends on the driver,
            // a caller-supplied query is always used as-is.
            sql = driver.contains("oracle") ? "select 1 from dual" : "select 1";
        }
        Connection conn = null;
        Statement stmt = null;
        ResultSet rs = null;
        try {
            conn = getConn();
            stmt = conn.createStatement();
            rs = stmt.executeQuery(sql);
            return true;
        } catch (SQLException e) {
            logger.error("test sql failed, err is {}", e.getMessage(), e);
            throw new RuntimeException(e);
        } finally {
            close(conn, stmt, rs);
        }
    }

    /**
     * Derives the database name from the JDBC URL: the text after the last {@code '/'}, with
     * any {@code '?'} query string removed. (Reflection on the connection would be an
     * alternative; this implementation parses the URL.)
     * <p>
     * If the URL contains no {@code '/'}, the whole URL (minus any query string) is returned.
     *
     * @return the database name, or {@code ""} if no URL is configured
     */
    public String getDataBase() {
        if (jdbcUrl == null) {
            return "";
        }
        int queryStart = jdbcUrl.indexOf('?');
        String base = queryStart > -1 ? jdbcUrl.substring(0, queryStart) : jdbcUrl;
        return base.substring(base.lastIndexOf('/') + 1);
    }

    /**
     * Closes the given JDBC resources in reverse order of creation (result set, statement,
     * connection). {@code null} arguments are skipped; close failures are logged and do not
     * prevent the remaining resources from being closed.
     *
     * @param con connection to close, may be {@code null}
     * @param ps  statement to close, may be {@code null}
     * @param rs  result set to close, may be {@code null}
     */
    public static void close(Connection con, Statement ps, ResultSet rs) {
        if (rs != null) {
            try {
                rs.close();
            } catch (SQLException e) {
                logger.warn("failed to close ResultSet", e);
            }
        }
        if (ps != null) {
            try {
                ps.close();
            } catch (SQLException e) {
                logger.warn("failed to close Statement", e);
            }
        }
        if (con != null) {
            try {
                con.close();
            } catch (SQLException e) {
                logger.warn("failed to close Connection", e);
            }
        }
    }

    /**
     * Upper-cases {@code value} with {@link Locale#ROOT} so the result does not depend on the
     * JVM default locale (e.g. the Turkish dotted/dotless i).
     *
     * @param value text to upper-case, may be {@code null}
     * @return the upper-cased text, or {@code null} if {@code value} is {@code null}
     */
    private static String upperCase(String value) {
        return value == null ? null : value.toUpperCase(Locale.ROOT);
    }
}
